from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class TrainCfg:
    """Base set of default training options.

    Every option is a dataclass field with a default, so ``TrainCfg()`` can be
    built without arguments. Fields are grouped by the section comments below.

    NOTE(review): ``task`` is annotated ``str`` but defaults to the integer
    ``0``, and several ``float``-annotated fields (``cost_limit``,
    ``estep_dual_max``, ``reward_threshold``) default to ``int`` literals.
    The defaults are deliberately left untouched here; confirm the intended
    values before changing them.
    """

    # ---- general task params ----
    # NOTE(review): str annotation with an int default -- presumably a
    # placeholder that callers always override; confirm.
    task: str = 0
    cost_limit: float = 1
    # NOTE(review): presumably a torch-style device string; the default names
    # one specific device index.
    device: str = "cuda:7"
    # The default points at one specific run directory under "logs/".
    pretrain_path: str = "logs/harzard_world/MiniGrid-HazardWorld-B-G-v1-cost-1/cvpo_disc_fake_next_obsTrue_fake_rew-10.0_fake_weight0.1_use_shieldTrue-b00a"
    pretrain_load: bool = False
    max_episode_steps: int = 200
    thread: int = 4  # only relevant when training on "cpu"
    seed: int = 0

    # ---- CVPO arguments ----
    estep_iter_num: int = 1
    estep_kl: float = 0.02
    estep_dual_max: float = 20
    estep_dual_lr: float = 0.02
    sample_act_num: int = 16
    mstep_iter_num: int = 1
    mstep_kl: float = 0.005
    mstep_dual_max: float = 0.5
    mstep_dual_lr: float = 0.1
    actor_lr: float = 0.0001
    distill_lr: float = 0.0003
    critic_lr: float = 0.0003
    gamma: float = 0.97
    n_step: int = 1
    tau: float = 0.05
    hidden_sizes: Tuple[int, ...] = (128, 128)
    double_critic: bool = False
    conditioned_sigma: bool = True
    unbounded: bool = False
    last_layer_scale: bool = False
    disc_sample: bool = True
    # Switches and weights whose names refer to shielding and "fake" data.
    llm_shield: bool = False
    use_deepseek: bool = False
    use_shield: bool = False
    use_shield_distill: bool = False
    use_fake_data: bool = True
    fake_next_obs: bool = True
    fake_done: bool = False
    int_reward: bool = False
    fake_rew: float = -10.0
    fake_cost: float = 10.0
    fake_weight: float = 1.0
    shield_prop: float = 1.0
    cost_shield_thres: float = 0.01
    actor_shield_thres: float = 0.02
    prepare_steps: int = 2_000

    # ---- collecting params ----
    epoch: int = 100
    episode_per_collect: int = 2
    step_per_epoch: int = 1_000
    update_per_step: float = 0.4
    buffer_size: int = 200_000
    worker: str = "ShmemVectorEnv"
    training_num: int = 2
    testing_num: int = 20

    # ---- general train params ----
    batch_size: int = 4096
    reward_threshold: float = 10_000  # used for early stopping
    save_interval: int = 4
    deterministic_eval: bool = False
    action_scaling: bool = False
    action_bound_method: str = "clip"
    resume: bool = False  # TODO
    save_ckpt: bool = True  # True means the policy model is saved
    verbose: bool = False
    render: bool = False

    # ---- logger params ----
    logdir: str = "logs"
    # The spelling "harzard" matches the directory used in pretrain_path.
    project: str = "harzard_world"
    group: Optional[str] = None
    name: Optional[str] = None
    prefix: Optional[str] = "cvpo_disc"
    suffix: Optional[str] = ""


# bullet-safety-gym task default configs


@dataclass
class Bullet1MCfg(TrainCfg):
    """TrainCfg preset that declares ``epoch = 100``; all else is inherited.

    The value equals the TrainCfg default for ``epoch``, so as written this
    preset produces the same defaults as TrainCfg itself.

    NOTE(review): if "1M" refers to total steps computed as
    ``epoch * step_per_epoch``, the inherited ``step_per_epoch`` of 1000 gives
    100,000 rather than 1,000,000 -- confirm which value is intended.
    """
    epoch: int = 100


@dataclass
class Bullet5MCfg(TrainCfg):
    """TrainCfg preset that raises ``epoch`` to 500; all else is inherited.

    NOTE(review): if "5M" refers to total steps computed as
    ``epoch * step_per_epoch``, the inherited ``step_per_epoch`` of 1000 gives
    500,000 rather than 5,000,000 -- confirm which value is intended.
    """
    epoch: int = 500


@dataclass
class Bullet10MCfg(TrainCfg):
    """TrainCfg preset that raises ``epoch`` to 1000; all else is inherited.

    NOTE(review): if "10M" refers to total steps computed as
    ``epoch * step_per_epoch``, the inherited ``step_per_epoch`` of 1000 gives
    1,000,000 rather than 10,000,000 -- confirm which value is intended.
    """
    epoch: int = 1000
    # Commented-out overrides labelled "Drone-Run"; they are not active.
    # NOTE(review): mstep_kl_mu and mstep_kl_std are not fields declared on
    # TrainCfg in this file (it declares mstep_kl), so re-enabling these lines
    # as-is would add new fields instead of overriding existing ones.
    # estep_kl: float = 0.001
    # mstep_kl_mu: float = 0.0002
    # mstep_kl_std: float = 0.0001


# safety gymnasium task default configs


@dataclass
class MujocoBaseCfg(TrainCfg):
    """Common defaults for the Mujoco* presets that subclass this class.

    Only the fields listed here differ from (or restate) TrainCfg; everything
    else is inherited. The overrides are written in TrainCfg's declaration
    order. Re-declaring an inherited field keeps its original position among
    the dataclass fields, so the order used here has no effect on the
    generated ``__init__``.
    """

    # general task params
    task: str = "SafetyPointCircle1Gymnasium-v0"
    cost_limit: float = 25

    # CVPO arguments
    gamma: float = 0.995
    n_step: int = 3
    unbounded: bool = True

    # collecting params
    epoch: int = 250
    step_per_epoch: int = 20_000
    buffer_size: int = 200_000  # same value as the TrainCfg default


@dataclass
class Mujoco2MCfg(MujocoBaseCfg):
    """MujocoBaseCfg preset that lowers ``epoch`` to 100; all else is inherited.

    With the inherited ``step_per_epoch`` of 20000 this is 100 * 20000 =
    2,000,000, which presumably is what "2M" in the name refers to.
    """
    epoch: int = 100


@dataclass
class Mujoco5MCfg(MujocoBaseCfg):
    """MujocoBaseCfg preset with a smaller buffer, lower gamma and bounded output.

    ``epoch`` and ``n_step`` restate the MujocoBaseCfg values; the remaining
    three fields change them. Overrides are written in TrainCfg's declaration
    order, which does not affect the dataclass field order.
    """

    # CVPO arguments
    gamma: float = 0.98
    n_step: int = 3  # same value as MujocoBaseCfg
    unbounded: bool = False

    # collecting params
    epoch: int = 250  # same value as MujocoBaseCfg
    buffer_size: int = 40_000


@dataclass
class Mujoco20MCfg(MujocoBaseCfg):
    """MujocoBaseCfg preset with ``epoch`` 1000 and ``sample_act_num`` 64.

    All other defaults are inherited. Overrides are written in TrainCfg's
    declaration order, which does not affect the dataclass field order.
    """

    # CVPO arguments
    sample_act_num: int = 64

    # collecting params
    epoch: int = 1000


@dataclass
class Mujoco10MCfg(MujocoBaseCfg):
    """MujocoBaseCfg preset with ``epoch`` 500 and adjusted CVPO defaults.

    Changes ``sample_act_num``, ``gamma`` and ``unbounded`` relative to
    MujocoBaseCfg; all other defaults are inherited. Overrides are written in
    TrainCfg's declaration order, which does not affect the dataclass field
    order.
    """

    # CVPO arguments
    sample_act_num: int = 32
    gamma: float = 0.98
    unbounded: bool = False

    # collecting params
    epoch: int = 500
